-
Notifications
You must be signed in to change notification settings - Fork 895
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat/Pass kwargs to the underlying models fit functions #2460
base: master
Are you sure you want to change the base?
Feat/Pass kwargs to the underlying models fit functions #2460
Conversation
(cherry picked from commit e80fe3fae4617033a6a4cde77afcd40c3072db33)
(cherry picked from commit 0a7b9fe8a1dc78fb9ef14a87cbd5e152e109eb2e)
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #2460 +/- ##
==========================================
- Coverage 93.78% 93.77% -0.02%
==========================================
Files 139 139
Lines 14704 14689 -15
==========================================
- Hits 13790 13774 -16
- Misses 914 915 +1 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a log @DavidKleindienst for this PR 🚀 It looks already very good!
I added a couple of suggestions, mainly about:
- we should try to allow fit
**kwargs
for all forecasting models - rename
fit_kwargs
tokwargs
- pass the kwargs to the underlying model's fit(), even if they don't support it. It might be a bit more informative to let the model raise an error, rather then swallowing the kwargs. And it would also automatically add support for dependencies that did not yet have additional parameters in
fit()
, but which might at some point add some parameters.
self, | ||
series: TimeSeries, | ||
future_covariates: Optional[TimeSeries] = None, | ||
**fit_kwargs, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would call it kwargs
since we already have this for regression models.
Also, should we add this to all forecasting models? E.g. ForecastingModel
, and all it's children (local models, ensemble models, regression models, torch models)? In some cases like TorchForecastingModel
, it will just not do anything, but at least there is a uniform API.
**fit_kwargs, | |
**kwargs, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The main concern I have here is the EnsembleModel
In principle I see 3 possibilities of dealing with the kwargs
in this case
- Do nothing and ignore the kwargs
- That could be confusing because I might have an
EnsembleModel
of a set of closely related models which all support the same keyword argument. Then it would be pretty confusing for the User to have kwargs in the function signature but not have them passed to the models
- That could be confusing because I might have an
- Pass kwargs to all of the models
- That's pretty impractical because some of my models may accept a certain argument and others won't
- instead define
kwargs: list[dict]
which needs to correspond to the number of models.EnsembleModel.forecasting_model[0].fit
gets passedkwargs[0]
and so on- That option makes most sense to me, but of course we won't really have a unified API in this case
@dennisbader What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking through the code, there are some more models where I find it difficult to deal with the kwargs
, namely models that are mainly implemented within darts than wrapped from another library, such as Theta
, FFT
or the Baseline Models.
Let's take for example the Theta
model:
There are 1 or 2 calls to hw.SimpleExpSmoothing
inside the Theta.fit
function, so we could pass the kwargs
in those calls (hw.SimpleExpSmoothing
seems to allow for meaningful keyword arguments), but I know too little about the Theta
model to judge if that`s a meaningful expansion
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The main concern I have here is the
EnsembleModel
In principle I see 3 possibilities of dealing with thekwargs
in this case
Do nothing and ignore the kwargs
- That could be confusing because I might have an
EnsembleModel
of a set of closely related models which all support the same keyword argument. Then it would be pretty confusing for the User to have kwargs in the function signature but not have them passed to the modelsPass kwargs to all of the models
- That's pretty impractical because some of my models may accept a certain argument and others won't
instead define
kwargs: list[dict]
which needs to correspond to the number of models.EnsembleModel.forecasting_model[0].fit
gets passedkwargs[0]
and so on
- That option makes most sense to me, but of course we won't really have a unified API in this case
@dennisbader What do you think?
Yes, I see your point. We could do something like this:
**kwargs
should be passed only to the ensemble model itself (not the underlyingforecasting_models
). Currently this will only have an effect forRegressionEnsembleModel
, where the ensemble model is one of DartsRegressionModels
.- add a new parameter like your option 3 that expects a single dict or a list of dicts (
Optional[Union[Dict[str, Any]], List[Dict[str, Any]]]
) that are passed to the forecasting modelsfit()
.- if a single dict, pass the same to all models
- if a list of dicts, the length must match the number of forecasting models. Pass each dict to the correspoinding fc model.
What do you think?
- Allowed passing of kwargs to the `fit` functions of `Prophet` and `AutoARIMA` | ||
- 🔴 Restructured the signatures of `ExponentialSmoothing` `__init__` and `fit` functions so that the passing of additional parameters is consistent with other models | ||
- Keyword arguments to be passed to the underlying model's constructor must now be passed as keyword arguments instead of a `dict` to the `ExponentialSmoothing` constructor | ||
- Keyword arguments to be passed to the underlying model's `fit` function must now be passed to the `ExponentialSmoothing.fit` function instead of the constructor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Once all suggestions have been addressed, we could formulate it as below :)
- Allowed passing of kwargs to the `fit` functions of `Prophet` and `AutoARIMA` | |
- 🔴 Restructured the signatures of `ExponentialSmoothing` `__init__` and `fit` functions so that the passing of additional parameters is consistent with other models | |
- Keyword arguments to be passed to the underlying model's constructor must now be passed as keyword arguments instead of a `dict` to the `ExponentialSmoothing` constructor | |
- Keyword arguments to be passed to the underlying model's `fit` function must now be passed to the `ExponentialSmoothing.fit` function instead of the constructor | |
- Improvements to `ForecastingModel` : [#2460](https://github.com/unit8co/darts/pull/2460) by [DavidKleindienst](https://github.com/DavidKleindienst). | |
- All forecasting models now support keyword arguments `**kwargs` when calling `fit()` that will be passed to the underlying model's fit function. | |
- 🔴 Changes to `ExponentialSmoothing` for a unified API: | |
- Removed `fit_kwargs` from `__init__()`. They must now be passed as keyword arguments to `fit()`. | |
- Parameters to be passed to the underlying model's constructor must now be passed as keyword arguments (instead of an explicit `kwargs` parameter). |
…_fit_functions # Conflicts: # CHANGELOG.md
Thank you very much for the review and suggestions @dennisbader! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm... I'm actually rethinking whether allowing this type of customization for all models makes sense (sorry for that, I'm still open for discussion, also @madtoinou if you have an opinion on it).
Why?
For me it seems like the parameters for ExponentialSmoothing should actually be part of init, since those are kind of hyperparameters, that one should be able to tune. Since in (almost) all other models we follow the logic to have the hparams in the constructor, this would break a bit the "unified" design. I'm aware that this might seem unintuitive if you're used to statsmodels ExponentialSmoothing
itself, but from a Darts perspective, I think it's okay to have these fit parameters in the Darts model constructor (they would even seem more natural to me as explicit named parameters in the constructor).
The same applies to Prophet and ARIMA. I think here we could also add the fit_kwargs
to the constructor. Other libraries do it similarly (e.g. sktime here).
Regarding Theta and FFT, I would leave it as it is.
Regarding baseline models (if you mean the naive models), leave it as it is (there shouldn't be any more customization required, if I'm not missing anything).
Sorry for taking so long for replying @dennisbader, summer has been very busy for me. I think what you're saying make perfect sense - if we pass fit_kwargs in the constructor, this will allow easy compatibility with I just have one more question: |
Checklist before merging this PR:
Fixes #2438 .
Summary
Basically as described in #2438
FutureCovariatesLocalForecastingModel.fit
function to pass**kwargs
to each models._fit
function. There the arguments are either swallowed (if the underlying model'sfit
function does not support additional arguments) or passed to the underlying model's fit function (Prophet
,AutoArima
).ExponentialSmoothing
function signatures for consistency:ExponentialSmoothing
constructor (rather than as adict
)fit
function can now be specified as kwargs in theExponentialSmoothing.fit
functionOther Information